﻿#include "pch.h"
#include "aes.h"
#include "aes_encryptor.h"

using namespace std;

// Reads the whole of `inputFile` into `strBuffer`, zero-padding the buffer up
// to the next multiple of 16 bytes.
//
// Parameters:
//   inputFile - path of the file to read.
//   strBuffer - receives the file contents followed by NUL padding. It is left
//               untouched when the file cannot be opened and is cleared when
//               the file is empty or its size cannot be determined.
//
// Returns the padded buffer size in bytes, or 0 when nothing was read.
// The size is narrowed to `unsigned int`, so files of 4 GiB or more are not
// supported by this interface.
unsigned int ReadFile(const char* inputFile, std::string& strBuffer)
{
	// Binary mode keeps the byte count reported by tellg() consistent with the
	// number of bytes read() delivers; text mode may translate line endings.
	std::ifstream ifs(inputFile, std::ios_base::in | std::ios_base::binary);
	if (!ifs.is_open())
	{
		return 0;
	}

	ifs.seekg(0, std::ios_base::end);
	const std::streamoff fileSize = ifs.tellg();
	ifs.seekg(0, std::ios_base::beg);

	// tellg() reports -1 on failure; an empty file has nothing to read either.
	// Bailing out here also avoids taking the address of an empty string's data.
	if (fileSize <= 0)
	{
		strBuffer.clear();
		return 0;
	}

	// Round the size up to a whole number of 16-byte blocks.
	const std::streamoff blockSize = 16;
	const std::streamoff paddedSize = (fileSize + blockSize - 1) / blockSize * blockSize;

	strBuffer.assign(static_cast<std::string::size_type>(paddedSize), '\0');

	// Only the real file bytes are requested; the tail stays zero-filled.
	ifs.read(&strBuffer[0], static_cast<std::streamsize>(fileSize));

	return static_cast<unsigned int>(paddedSize);
}

// Builds the encryptor around a heap-allocated AES engine constructed from
// `key`. The engine is released again in the destructor.
// NOTE(review): the required key length/format is defined by AES(key), which
// is not visible in this file - confirm against aes.h.
AesEncryptor::AesEncryptor(unsigned char* key)
    : m_pEncryptor(new AES(key))
{
}


// Destroys the AES engine held in m_pEncryptor.
// NOTE(review): the object owns a raw pointer, so copying an AesEncryptor
// would lead to a double delete unless the header disables copying - confirm.
AesEncryptor::~AesEncryptor()
{
    delete m_pEncryptor;
}

// Converts `len` bytes from `src` into upper-case hexadecimal text in `dest`.
//
// Every byte becomes two characters, and a terminating NUL is written right
// after the most recently emitted pair. `dest` must therefore have room for
// 2 * len + 1 chars. Nothing is written when `len` is zero or negative.
void AesEncryptor::Byte2Hex(const unsigned char* src, int len, char* dest) {
    static const char kHexDigits[] = "0123456789ABCDEF";

    char* out = dest;
    for (int index = 0; index < len; ++index) {
        const unsigned char value = src[index];
        out[0] = kHexDigits[value >> 4];    // high nibble
        out[1] = kHexDigits[value & 0x0F];  // low nibble
        out[2] = '\0';                      // overwritten by the next pair, if any
        out += 2;
    }
}

// Converts hexadecimal text back into raw bytes.
//
// `src` holds `len` characters; each pair of characters yields one byte, so
// len / 2 bytes are stored in `dest` (a trailing odd character is ignored).
// Characters are not validated: Char2Int returns -1 for a non-hex character
// and that value is folded into the resulting byte as-is.
void AesEncryptor::Hex2Byte(const char* src, int len, unsigned char* dest) {
    const int byteCount = len / 2;

    int produced = 0;
    while (produced < byteCount) {
        const int highNibble = Char2Int(src[2 * produced]);
        const int lowNibble = Char2Int(src[2 * produced + 1]);
        dest[produced] = static_cast<unsigned char>(highNibble * 16 + lowNibble);
        ++produced;
    }
}

// Maps one hexadecimal digit character to its numeric value.
// Accepts '0'-'9', 'a'-'f' and 'A'-'F'; returns -1 for any other character.
int AesEncryptor::Char2Int(char c) {
    if (c >= '0' && c <= '9') {
        return c - '0';
    }
    if (c >= 'A' && c <= 'F') {
        return 10 + (c - 'A');
    }
    if (c >= 'a' && c <= 'f') {
        return 10 + (c - 'a');
    }
    return -1;
}

// Encrypts `strInfor` and returns the ciphertext as upper-case hex text.
//
// The plaintext is zero-padded to a multiple of 16 bytes before encryption.
// Padding is always added (1 to 16 bytes), so an input whose length is already
// a multiple of 16 grows by a full 16-byte block. The returned string is twice
// the padded length.
string AesEncryptor::EncryptString(string strInfor) 
{
    const int nLength = static_cast<int>(strInfor.length());
    const int paddedLength = nLength + (16 - (nLength % 16));

    // strInfor is this function's own copy, so it is padded and encrypted in
    // place; std::string owns the memory and releases it on every exit path.
    strInfor.resize(paddedLength, '\0');
    m_pEncryptor->Cipher(reinterpret_cast<unsigned char*>(&strInfor[0]), paddedLength);

    // Convert the encrypted bytes into a hexadecimal string.
    // Byte2Hex writes a terminating NUL after the last digit pair, so the
    // output needs one extra char beyond the 2 * paddedLength hex digits.
    std::string hexText(2 * static_cast<std::string::size_type>(paddedLength) + 1, '\0');
    Byte2Hex(reinterpret_cast<const unsigned char*>(strInfor.data()), paddedLength, &hexText[0]);

    // Drop the slot that only existed to hold the terminator.
    hexText.resize(hexText.size() - 1);
    return hexText;
}

// Decrypts a hex-encoded ciphertext string and returns the raw plaintext bytes.
//
// `strMessage` is hex text; every two characters form one ciphertext byte (a
// trailing odd character is ignored). The result has strMessage.length() / 2
// bytes and is returned exactly as the cipher produced it - no padding is
// stripped here. A message too short to hold a single byte yields an empty
// string.
// NOTE(review): InvCipher is handed whatever length the input implies; if it
// requires a multiple of 16 bytes, malformed input may need rejecting here -
// confirm against aes.h.
string AesEncryptor::DecryptString(string strMessage) 
{
    const int nLength = static_cast<int>(strMessage.length() / 2);
    if (nLength <= 0)
    {
        // Nothing to decode; also avoids taking a pointer into an empty buffer.
        return string();
    }

    std::string buffer(nLength, '\0');
    char* pBuffer = &buffer[0];
    Hex2Byte(strMessage.c_str(), static_cast<int>(strMessage.length()),
             reinterpret_cast<unsigned char*>(pBuffer));

    // Decrypts the decoded bytes in place.
    m_pEncryptor->InvCipher(pBuffer, nLength);

    return buffer;
}

void AesEncryptor::EncryptTxtFile(const char* inputFileName, const char* outputFileName)
{
	string strInfor;
	size_t fileLength = ReadFile(inputFileName, strInfor);
	if (fileLength <= 0)
	{
		return;
	}

    // Encrypt
    string strLine = EncryptString(strInfor);

    // Writefile
    ofstream ofs;
    ofs.open(outputFileName, ios_base::out | ios_base::binary);
    if (!ofs) {
        printf("AesEncryptor::EncryptTxtFile() - Open output file failed!");
        return ;
    }
    ofs << strLine;
    ofs.close();
}

void AesEncryptor::DecryptTxtFile(const char* inputFile, const char* outputFile) 
{
	// Writefile
	ofstream ofs;
	ofs.open(outputFile, ios_base::out | ios_base::binary);
	if (!ofs) {
		printf("AesEncryptor::DecryptTxtFile() - Open output file failed!");
		return;
	}
    DecryptTxtFile(inputFile, ofs);
    ofs.close();
}

// Decrypts the hex-text file `inputFileName` and streams the plaintext to the
// already-open `ofs`.
//
// Writes nothing when ReadFile yields no data (it returns 0). The stream is
// neither opened nor closed here; that is the caller's responsibility.
void AesEncryptor::DecryptTxtFile(const char* inputFileName, ofstream& ofs)
{
    std::string hexText;
    if (ReadFile(inputFileName, hexText) == 0)
    {
        return;
    }

    // hexText is not needed afterwards, so hand it over instead of copying.
    ofs << DecryptString(std::move(hexText));
}
